from typing import Union

from fastapi import FastAPI, File, UploadFile , Form

app = FastAPI()  # FastAPI application object; the @app.post / @app.get routes below register on it
import os
# from llama_index import VectorStoreIndex, SimpleDirectoryReader, LLMPredictor, ServiceContext
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex, Document
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import QueryBundle
import pandas as pd


from copy import deepcopy
from langchain_community.llms.baichuan import BaichuanLLM
from langchain_community.embeddings import BaichuanTextEmbeddings
from llama_index.core.postprocessor import LLMRerank


from llama_index.vector_stores.milvus import MilvusVectorStore

from pymilvus import connections, Collection


# Baichuan API key. Read from the environment so a real key does not have to be committed
# to source control; the original placeholder remains the fallback for backward compatibility.
API_KEY = os.environ.get('BAICHUAN_API_KEY', 'XXXXXXx输入api')

# Global LlamaIndex defaults (LLM, chunking, embeddings) used by every index and query
# engine created in this module.
Settings.llm = BaichuanLLM(baichuan_api_key=API_KEY)
Settings.chunk_overlap = 100
Settings.chunk_size = 600
Settings.embed_model = BaichuanTextEmbeddings(baichuan_api_key=API_KEY)
# Vector dimension passed to MilvusVectorStore when a collection is created.
# NOTE(review): must match the output size of the Baichuan embedding model — confirm.
DIMENSION = 1024

# Milvus server address. Overridable through MILVUS_URI; defaults to a local server
# (e.g. a Docker deployment on the default port).
URI = os.environ.get('MILVUS_URI', "http://localhost:19530")
def create_mulivus_collection(collectionName, URI, dim=DIMENSION, overwrite=False):
    '''
    Instantiate a MilvusVectorStore for the given collection.

    collectionName: name of the Milvus collection.
    URI: Milvus address. With a Docker/server deployment use e.g. "http://localhost:19530";
         for a local Milvus Lite file use e.g. "./milvus_llamaindex.db".
    dim: vector dimension handed to MilvusVectorStore (defaults to DIMENSION).
    overwrite: whether MilvusVectorStore should overwrite an existing collection of the same
               name. Defaults to False, matching the previous behaviour.

    Returns the MilvusVectorStore itself (not an index); pass it to
    get_index_from_collection to obtain a VectorStoreIndex.
    '''
    vector_store = MilvusVectorStore(
        uri=URI,
        dim=dim,
        collection_name=collectionName,
        overwrite=overwrite,
    )
    print('Created MilvusVectorStore.')
    return vector_store
def get_index_from_collection(vector_store):
    '''
    Wrap an existing vector store in a VectorStoreIndex.

    vector_store: the vector store (e.g. a MilvusVectorStore) to build the index on.
    Returns the VectorStoreIndex created with VectorStoreIndex.from_vector_store.
    '''
    loaded_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
    print('Loaded VectorStoreIndex from collection.')
    return loaded_index


def query_question(index, question):
    '''
    Answer a question against an index using its default query engine.

    index: the index to query (anything exposing as_query_engine()).
    question: the question text.
    Returns the query engine's response object, which is also printed to stdout.
    '''
    answer = index.as_query_engine().query(question)
    print(answer)
    return answer
def chatllm_query_quetion(aicubellm, index, question):
    '''
    Answer a question through a chat engine built from the index.

    aicubellm: the LLM handed to the chat engine (passed as llm=).
    index: the index providing as_chat_engine().
    question: the user's message.
    Returns the chat engine's response, which is also printed to stdout.
    '''
    engine_options = {"chat_mode": "openai", "llm": aicubellm, "verbose": True}
    reply = index.as_chat_engine(**engine_options).chat(question)
    print(reply)
    return reply


# from llama_index.core import Document

async def create_doc_from_content_byfileName(content: Union[bytes, str], metadata: Union[dict, None] = None) -> Document:
    """Create a Document from file content, attaching extra metadata.

    content: the document body. Bytes (e.g. the result of UploadFile.read()) are decoded
             as UTF-8; a str (e.g. form text) is used as-is.
    metadata: optional dict stored as the Document's extra_info, e.g. {"file_name": ...}.
              Defaults to an empty dict. The caller's dict is copied, not shared.
    Returns the new Document.
    Raises UnicodeDecodeError if bytes content is not valid UTF-8.
    """
    # Decode explicitly rather than relying on the Document model to coerce bytes to text.
    text = content.decode('utf-8') if isinstance(content, (bytes, bytearray)) else content
    # A fresh dict per call: avoids the shared mutable default and aliasing the caller's dict.
    extra_info = dict(metadata) if metadata else {}
    return Document(text=text, extra_info=extra_info)
async def insert_file_to_knowledge_base_with_filename(collectionName: str, file: UploadFile = File(...)):
    '''
    Read an uploaded file and insert it into the Milvus collection as a single Document.

    The Document carries metadata {"file_name": file.filename}, the same key that
    delete_vectors_by_fileName filters on.
    Returns a success string, or {"error": <message>} if any step raises.
    '''
    try:
        raw_bytes = await file.read()
        document = await create_doc_from_content_byfileName(
            raw_bytes, metadata={"file_name": file.filename}
        )
        print(document)

        store = MilvusVectorStore(uri=URI, dim=DIMENSION, collection_name=collectionName)
        get_index_from_collection(store).insert(document)
    except Exception as exc:
        return {"error": str(exc)}
    return f"Insert '{file.filename}' successfully."
async def create_doc_from_content(content: bytes) -> Document:
    """Build a metadata-free Document from UTF-8 encoded file bytes."""
    decoded_text = content.decode('utf-8')
    return Document(text=decoded_text)



async def insert_file_to_knowledge_base_wo_save(collectionName: str, file: UploadFile = File(...)):
    '''
    Insert an uploaded file into the Milvus collection without saving it to disk
    and without attaching any metadata.

    Returns a success string, or {"error": <message>} if any step raises.
    '''
    try:
        payload = await file.read()
        document = await create_doc_from_content(payload)
        print(document)
        target_index = get_index_from_collection(
            MilvusVectorStore(uri=URI, dim=DIMENSION, collection_name=collectionName)
        )
        target_index.insert(document)
    except Exception as exc:
        return {"error": str(exc)}
    return f"Insert '{file.filename}' successfully."





def update_knowledge_base(index, input_file):
    '''
    Update the knowledge base by loading one file and inserting its contents into an index.

    index: the index to update.
    input_file: path of the file to load with SimpleDirectoryReader (see
        https://docs.llamaindex.ai/en/stable/module_guides/loading/simpledirectoryreader/
        for supported formats).

    SimpleDirectoryReader.load_data() returns a list of Documents, and readers for some
    formats (e.g. PDF) return more than one per file, so every Document is inserted rather
    than only the first.

    Returns the number of Documents inserted.
    Raises ValueError if no Documents could be loaded from the file.
    '''
    new_docs = SimpleDirectoryReader(input_files=[input_file]).load_data()
    if not new_docs:
        raise ValueError(f"No documents could be loaded from '{input_file}'.")
    for doc in new_docs:
        index.insert(doc)
    return len(new_docs)


def get_retrieved_nodes(
    query_str, vector_top_k=10, reranker_top_n=3, with_reranker=False, index=None
):
    '''
    Retrieve the nodes most similar to a query, optionally reranked by an LLM.

    query_str: the query text.
    vector_top_k: number of nodes returned by the vector retriever.
    reranker_top_n: number of nodes kept after reranking.
    with_reranker: whether to rerank the retrieved nodes with LLMRerank.
    index: the VectorStoreIndex to search. Optional for backward compatibility: when
           omitted, a module-level global named `index` is used, as before.

    Returns the list of retrieved (and possibly reranked) nodes.
    Raises ValueError if no index is passed and no global `index` exists.
    '''
    if index is None:
        # No module-level `index` is assigned in this file's code, so the old implicit global
        # lookup raised NameError unless a caller injected one. Keep the fallback, but fail clearly.
        index = globals().get("index")
        if index is None:
            raise ValueError("get_retrieved_nodes() needs an index: pass index=... explicitly.")

    query_bundle = QueryBundle(query_str)
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=vector_top_k,
    )
    retrieved_nodes = retriever.retrieve(query_bundle)

    if with_reranker:
        # No llm is passed, so the reranker presumably uses the global Settings.llm — confirm.
        reranker = LLMRerank(
            choice_batch_size=5,
            top_n=reranker_top_n,
        )
        retrieved_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)

    return retrieved_nodes

def visualize_retrieved_nodes(nodes) -> str:
    '''
    Render retrieved nodes as an HTML table (via a pandas DataFrame) and print it.

    nodes: iterable of scored nodes, each exposing .score and .node.get_text().
    Returns the printed HTML string so callers (e.g. notebooks) can reuse it.
    '''
    # Nodes are only read (score and get_text()), so they are not deep-copied first.
    result_dicts = [
        {"Score": node.score, "Text": node.node.get_text().replace("\n", " ")}
        for node in nodes
    ]
    # Also strip literal backslash-n sequences ("\\n") that would otherwise appear in the HTML.
    html_content = pd.DataFrame(result_dicts).to_html().replace("\\n", "")
    print(html_content)
    return html_content

def print_index_struct(index):
    '''
    Print the internal structure of an index.

    index: the index to inspect. Its structure is read through the public `index_struct`
           property rather than the private `_index_struct` attribute.
    '''
    print(index.index_struct)

def print_node(index):
    '''
    Print how many entries the index's docstore holds.

    index: the index to inspect; its docstore.docs mapping is counted.
    NOTE(review): for a vector-store-backed index (e.g. Milvus) the docstore may not hold the
    nodes, so this can print 0 even when the collection has data — confirm before relying on it.
    '''
    docs_by_id = index.docstore.docs
    print(f"Number of nodes in the index: {len(docs_by_id)}")


def delete_collection(collection_name):
    '''
    Drop a Milvus collection if it exists.

    collection_name: name of the collection to drop.

    Connects to the Milvus server at URI, the same address the rest of this module uses.
    As before, problems are reported on stdout rather than raised.
    Returns True if the collection was dropped, False otherwise.
    '''
    # Local import: only this helper needs pymilvus' utility module.
    from pymilvus import utility

    try:
        connections.connect("default", uri=URI)
        # Check existence explicitly instead of matching error text. pymilvus reports a missing
        # collection as "... not exist ..." (NOTE(review): confirm for the installed version),
        # which the previous "does not exist" substring check did not match.
        if not utility.has_collection(collection_name):
            print(f"Collection '{collection_name}' does not exist.")
            return False
        utility.drop_collection(collection_name)
        print(f"Collection '{collection_name}' has been deleted.")
        return True
    except Exception as e:
        print(f"An error occurred while deleting the collection: {e}")
        return False

def print_entity_count(collection_name):
    '''
    Print the number of entities stored in a Milvus collection.

    collection_name: name of the collection.

    Connects to the Milvus server at URI, the same address the rest of this module uses.
    Errors (e.g. a missing collection) are printed rather than raised.
    Returns the entity count, or None if an error occurred.
    NOTE(review): num_entities may lag behind recent inserts until data is flushed — confirm.
    '''
    try:
        connections.connect("default", uri=URI)
        entity_total = Collection(collection_name).num_entities
        print(f"Number of entities in collection '{collection_name}': {entity_total}")
        return entity_total
    except Exception as e:
        print(f"An error occurred: {e}")
        return None

def index_persist(index, persist_dir="./storage"):
    '''
    Persist an index's storage context to disk and load it back.

    index: the index to persist.
    persist_dir: directory to write to and reload from. Defaults to "./storage", the
                 directory the previous implementation used.

    Returns the index reloaded from persist_dir. Previously the reloaded index was assigned to
    a local variable and discarded, so callers could not use it.

    NOTE(review): for a Milvus-backed index the vectors live in Milvus, not in persist_dir;
    confirm that reloading from the directory alone yields a usable index in that setup.
    '''
    # By default the stores live in memory; write them to persist_dir.
    index.storage_context.persist(persist_dir=persist_dir)

    # Rebuild a storage context from the persisted files and load the index from it.
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    return load_index_from_storage(storage_context)




@app.post("/create_vector_store/")
async def create_vector_store(collectionName: str):
    '''
    HTTP endpoint: create the Milvus vector store for `collectionName`
    (via create_mulivus_collection).

    collectionName: collection name (query parameter).
    Returns {"message": ...} on success or {"error": <message>} if creation raises.
    '''
    try:
        # The helper returns a MilvusVectorStore, which is not needed afterwards.
        create_mulivus_collection(collectionName, URI)
        return {"message": f"Vector store '{collectionName}' created"}
    except Exception as e:
        return {"error": str(e)}



@app.post("/query_question_from_collection/")
async def query_question_from_collection(collectionName: str, question: str):
    '''
    HTTP endpoint: answer `question` using the data stored in Milvus collection `collectionName`.

    collectionName: collection to search (query parameter).
    question: the question text (query parameter).
    Returns whatever query_question returns (the query engine's response object).
    NOTE(review): no dim is passed to MilvusVectorStore, and the query runs synchronously
    inside an async handler (blocking the event loop while it runs) — confirm both are intended.
    '''
    milvus_store = MilvusVectorStore(uri=URI, collection_name=collectionName)
    collection_index = VectorStoreIndex.from_vector_store(vector_store=milvus_store)
    return query_question(collection_index, question)

@app.post("/insert_file_to_KnowledgeBase_wo_save/")
async def insert_file_to_KnowledgeBase_wo_save(collectionName: str, file: UploadFile = File(None), text: str = Form(None), textFileName: str = Form(None)):
    '''
    HTTP endpoint: insert an uploaded file or raw text into a collection without saving it locally.

    collectionName: target Milvus collection (query parameter).
    file: optional uploaded file; takes precedence when both a file and text are sent.
    text / textFileName: raw text plus the name stored as its "file_name" metadata. Both are
        required when no file is uploaded.
    Returns the helper's success string for uploads, {"message": ...} for text, or
    {"error": ...} on failure or missing input.
    '''
    try:
        if file is not None:
            # Upload case: the helper stores file.filename as "file_name" metadata.
            return await insert_file_to_knowledge_base_with_filename(collectionName, file)
        if text is None or textFileName is None:
            return {"error": "No file or text provided."}

        # Text case: `file` is necessarily None here, so the reported name is textFileName.
        vector_store = MilvusVectorStore(uri=URI, dim=DIMENSION, collection_name=collectionName)
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
        doc = await create_doc_from_content_byfileName(text, metadata={"file_name": textFileName})
        index.insert(doc)
        return {"message": f"Insert '{textFileName}' successfully."}
    except Exception as e:
        return {"error": str(e)}







def _safe_path_component(name):
    '''
    Return `name` reduced to its final path component, or None if nothing usable remains.

    Applied to client-supplied names that become part of local file paths, so values such as
    "../../x" or "/etc/x" cannot address files outside the intended directory.
    '''
    if not name:
        return None
    component = os.path.basename(name)
    if component in ("", ".", ".."):
        return None
    return component


@app.post("/insert_data_to_KnowledgeBase/")
async def insert_data_to_KnowledgeBasee(collectionName: str, file: UploadFile = File(None), text: str = Form(None), textFileName: str = Form(None)):
    """
    HTTP endpoint: save an uploaded file or raw text under "./{collectionName}/" and index it.

    - file: saved as "./{collectionName}/{file name}" (directory components of the uploaded
      name are stripped).
    - text + textFileName: saved as UTF-8 "./{collectionName}/{textFileName}.txt".
    The saved file is then loaded with SimpleDirectoryReader (via update_knowledge_base) and
    inserted into the Milvus collection `collectionName`.

    Returns {"message": ...} on success or {"error": ...} on failure, including when neither
    a file nor text+textFileName is provided.
    """
    try:
        # collectionName becomes a directory name, so it must already be one plain component.
        if _safe_path_component(collectionName) != collectionName:
            return {"error": f"Invalid collection name '{collectionName}'."}

        if file is not None:
            # Upload case: write the raw bytes under the sanitised uploaded file name.
            stored_name = _safe_path_component(file.filename)
            if stored_name is None:
                return {"error": f"Invalid file name '{file.filename}'."}
            save_path = f"./{collectionName}/{stored_name}"
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with open(save_path, mode='wb') as f:
                f.write(await file.read())
        elif text is not None and textFileName is not None:
            # Text case: write the text as "<textFileName>.txt".
            stored_name = _safe_path_component(textFileName)
            if stored_name is None:
                return {"error": f"Invalid textFileName '{textFileName}'."}
            save_path = f"./{collectionName}/{stored_name}.txt"
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            with open(save_path, mode='w', encoding='utf-8') as f:
                f.write(text)
        else:
            # Without this early return, save_path would be unbound below and the client
            # would receive an UnboundLocalError message instead of a clear error.
            return {"error": "No file or text provided."}

        vector_store = MilvusVectorStore(uri=URI, dim=DIMENSION, collection_name=collectionName)
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
        update_knowledge_base(index, save_path)
        return {"message": f"Insert '{stored_name}' successfully."}
    except Exception as e:
        return {"error": str(e)}



def _milvus_quote(value):
    '''
    Return `value` as a single-quoted Milvus filter-expression string literal.

    Backslashes and single quotes are backslash-escaped, so a client-supplied value cannot end
    the literal early and append its own conditions (e.g. "x' or file_name != '").
    NOTE(review): relies on Milvus accepting backslash escapes inside quoted literals — confirm
    against the deployed Milvus version.
    '''
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _is_plain_path_component(name):
    '''Return True if `name` is a single, non-special path component (no directory parts).'''
    return bool(name) and name not in (".", "..") and os.path.basename(name) == name


@app.post("/delete_vectors_by_fileName/")
async def delete_vectors_by_fileName(collection_name: str, file_name: str):
    '''
    HTTP endpoint: delete the vectors whose "file_name" equals `file_name` from `collection_name`,
    plus the matching local copy under "./{collection_name}/" if one exists.

    Local-copy lookup order:
      1. "./{collection_name}/{file_name}" exists -> it is removed; vectors with
         file_name == file_name are deleted.
      2. else "./{collection_name}/{file_name}.txt" exists -> it is removed; vectors with
         file_name == file_name + ".txt" are deleted instead.
      3. otherwise only the vectors with file_name == file_name are deleted.
    The Milvus delete runs before the local file is removed, so a Milvus failure leaves the
    file in place.

    input: collection_name: str collection name, file_name: str file name
    Returns {"message": ...} on success or {"error": ...} on failure.
    '''
    try:
        target_name = file_name
        local_path = None
        message = {"message": f"Deleted vectors of file_name '{file_name}' from '{collection_name}' ."}

        # Only touch the filesystem when both client-supplied names are plain path components;
        # otherwise a value like "../../x" could delete files outside the collection directory.
        if _is_plain_path_component(collection_name) and _is_plain_path_component(file_name):
            save_path = f"./{collection_name}/{file_name}"
            txt_save_path = f"./{collection_name}/{file_name}.txt"
            if os.path.isfile(save_path):
                local_path = save_path
                message = {"message": f"Deleted local file and  vectors of file_name '{file_name}' from '{collection_name}' ."}
            elif os.path.isfile(txt_save_path):
                # Text saved by /insert_data_to_KnowledgeBase/ gets a ".txt" suffix, so its
                # vectors are looked up by the suffixed name.
                local_path = txt_save_path
                target_name = f"{file_name}.txt"
                message = {"message": f"Failed to find '{file_name}' in '{collection_name}', but deleted local file and the vectors of '{file_name}.txt'  instead."}

        connections.connect("default", uri=URI)
        collection = Collection(name=collection_name)
        # The value is escaped, not interpolated raw, so it cannot alter the filter expression.
        collection.delete(expr=f"file_name == {_milvus_quote(target_name)}")

        # Remove the local copy only after Milvus accepted the delete.
        if local_path is not None:
            os.remove(local_path)
        return message
    except Exception as e:
        return {"error": str(e)}


@app.get("/delete_MilvusVectorStore/")
async def delete_MilvusVectorStore(collectionName: str):
    '''
    HTTP endpoint (GET): drop the Milvus collection `collectionName` via delete_collection.

    NOTE(review): the success string is returned unconditionally. delete_collection reports
    failures (including a missing collection) by printing, so callers cannot tell them apart.
    '''
    delete_collection(collectionName)
    reply = f"Delete MilvusVectorStore '{collectionName}' Successfully."
    return reply




if __name__ == "__main__":  
    # Running this file directly does nothing: the only statement below is an inert string
    # literal that holds scratch/experiment code (it references names such as `url` that are
    # not defined anywhere in this file).
    # To serve the API, start it with uvicorn instead (assuming this module is main.py):
    #     uvicorn main:app --reload
    # or programmatically:
    #     import uvicorn
    #     uvicorn.run(app, host="0.0.0.0", port=8000)



    '''
    collectionName = 'llamaIndexTest'
    input_dir = 'E:\\work\\vector\\txt'

    URI = "http://localhost:19530"
    # URI = "./milvus_llamaindex.db"
    # Settings.llm = BaichuanLLM(baichuan_api_key=API_KEY)


    from langchain_openai import ChatOpenAI
    llmChatOpenAI = ChatOpenAI(
    openai_api_base=url,
    openai_api_key="",
    model_name="",
    temperature=0,
    max_tokens=8000,
    )
    Settings.llm = llmChatOpenAI



    Settings.chunk_overlap = 100
    Settings.chunk_size = 600
    Settings.embed_model = BaichuanTextEmbeddings(baichuan_api_key=API_KEY)

    ##从已有mulvus向量库中加载index
    vector_store = MilvusVectorStore(uri=URI, collection_name=collectionName)
    index = get_index_from_collection(vector_store) 

    #创建mulvus collelection
    # index = create_mulivus_collection(input_dir, collectionName, URI) 
    

    #向向量库插入新的数据 
    # newfile_path = 'E:\\work\\vector\\txt\\langchain.txt'
    # newfile_path = 'E:\\work\\vector\\txt\\llamaindex.txt'

    # update_knowledge_base(index, newfile_path)


    query_question(index, 'llamaindex是什么?')
    # chatllm_query_quetion(aicubellm, index, 'langchain和llamaindex的区别?')


    # query = ''
    # new_nodes = get_retrieved_nodes(
    #     # "What is Lyft's response to COVID-19?",
    #     query,
    #     vector_top_k=6, #检索相关节点的数量
    #     reranker_top_n=3, #重新排名节点的数量
    #     with_reranker=True, #重新排名
    # )
    # visualize_retrieved_nodes(new_nodes)
    # print_index_struct(index)


    query_question(index, '')



    # logging.debug('Insert result: %s', res)
    print_entity_count(collectionName)
    # print_entity_count('LangChainCollection')

    # index_persist(index)

    # delete_collection("llamaIndexTest")
    '''
